/**@@@+++@@@@******************************************************************
**
** Microsoft Windows Media
** Copyright (C) Microsoft Corporation. All rights reserved.
**
***@@@---@@@@******************************************************************
*/

#include "blackboxtest.h"
#include "drmcrt.h"
#include "tLicGen.h"
#include "tOEMIMP.h"
#include "teststubcerts.h"
/* This header is needed for the WINCE remove function. */
#include "tstutils.h"


/* Blackbox context shared by the test functions in this file. */
DRM_BB_CONTEXT g_BBContext;
/* Client ID handed to DRM_BBX_GetClientId by Test_API_GetClientID and
   inspected by Test_VerifyClientID. */
CLIENTID g_clientID;

/* Binding-info slots: zeroed in BB_PreTestCase, selected by index in
   Test_API_CipherKeySetup.
   MAX_GLOBAL_BINDING_BLOB is not defined in this file (presumably it comes
   from one of the included headers -- confirm). The former local definition
   is kept here for reference:
   #define MAX_GLOBAL_BINDING_BLOB 2 */
DRM_BINDING_INFO g_binding[MAX_GLOBAL_BINDING_BLOB];

/* Number of blob slots that test arguments can address by index. */
#define MAX_BLOB 5
/* Byte buffers allocated on demand by the tests; every slot is reset to NULL
   in BB_PreTestCase and released in BB_PostTestCase. */
static DRM_BYTE *g_blob[MAX_BLOB];

/* Number of cipher-context slots. The spelling "CIPHEER" (sic) is kept because
   the name is used by other functions in this file. */
#define MAX_CIPHEER 1
/* Cipher contexts selected by index in the cipher setup/encrypt/decrypt tests. */
DRM_CIPHER_CONTEXT g_cipherContext[MAX_CIPHEER];

/* Names of the key file and the secret file that are deleted before a test case. */
#define KEYFILE_NAME "v2ks.bla"
#define SECFILE_NAME "v2ks.sec"

/*
	Deletes the key file, the secret file and the DRM store file so that the
	next test case does not see files left by an earlier one.
	The results of the delete calls are not checked.
*/
static void CleanKeyfileAndDevStore()
{
	/* key material first ... */
	DX_VOS_FDelete(KEYFILE_NAME);
	DX_VOS_FDelete(SECFILE_NAME);

	/* ... then the secure store */
	RemoveDRMFile(RMFILE_STORE);
}

/*
	Fetches the device certificate into a local buffer, loads its properties
	into pcontextBBX->cachedCertValues and then calls DRM_BBX_Initialize on the
	context.
	pcontextBBX: blackbox context to fill in; it is dereferenced without a
	             NULL check.
	Returns dr as left by the ChkDR chain (ChkDR is assumed to store the call
	result in dr and jump to ErrorExit on failure -- confirm against its
	definition).
*/
DRM_RESULT TestGetDevCertValues( DRM_BB_CONTEXT *pcontextBBX )
{
	DRM_RESULT dr;
	DRM_CONST_STRING dstrCert;
	DRM_BYTE rgbCert[MAX_DEVICE_CERT_SIZE];

	/* DSTR_FROM_PB is given the string descriptor, the local buffer and its size */
	DSTR_FROM_PB( &dstrCert, rgbCert, SIZEOF(rgbCert) );

	/* step 1: obtain the device certificate */
	ChkDR( DRM_DDC_GetDeviceCertificate( (DRM_STRING*)&dstrCert,
	                                     0,
	                                     &pcontextBBX->CryptoContext ) );

	/* step 2: cache the certificate properties in the blackbox context */
	ChkDR( DRM_DCP_LoadPropertiesCache( &dstrCert,
	                                    &pcontextBBX->cachedCertValues,
	                                    &pcontextBBX->CryptoContext ) );

	/* step 3: initialize the blackbox itself */
	ChkDR( DRM_BBX_Initialize( pcontextBBX ) );

ErrorExit:
	return dr;
}

/*
	Per-test-case setup: clears the global binding slots and blob pointers,
	removes the key/secret/store files, sets up the device environment from
	"devcerttemplate.dat" and "priv.dat", then loads the device certificate
	values and initializes the global blackbox context.
	lTCID, strTCName: not used by this function.
	Returns dr as left by the last ChkDR that ran.
*/
DRM_RESULT BB_PreTestCase(long lTCID, char *strTCName)
{
	DRM_RESULT dr;
	int iSlot;
	/* "devcerttemplate.dat" */
	const DRM_WCHAR wszCertTemplate[] = { TWO_BYTES('d','e'), TWO_BYTES('v','c'),
										  TWO_BYTES('e','r'), TWO_BYTES('t','t'),
										  TWO_BYTES('e','m'), TWO_BYTES('p','l'),
										  TWO_BYTES('a','t'), TWO_BYTES('e','.'),
										  TWO_BYTES('d','a'), TWO_BYTES('t','\0') };
	/* "priv.dat" */
	const DRM_WCHAR wszPrivKey[] = { TWO_BYTES('p','r'), TWO_BYTES('i','v'), TWO_BYTES('.','d'),
									 TWO_BYTES('a','t'), TWO_BYTES('\0',0) };

	/* wipe every binding slot */
	DX_VOS_MemSet(g_binding, 0, sizeof(g_binding));

	/* no blob is allocated at the start of a test case */
	iSlot = 0;
	while (iSlot < MAX_BLOB) {
		g_blob[iSlot] = NULL;
		iSlot++;
	}

	CleanKeyfileAndDevStore();
	ChkDR(SetDeviceEnv(wszCertTemplate, wszPrivKey, TRUE));

	/* load the device certificate values into the global context */
	ChkDR(TestGetDevCertValues(&g_BBContext));

	/* NOTE(review): TestGetDevCertValues already calls DRM_BBX_Initialize, so
	   the context is initialized a second time here -- confirm this is intended. */
	ChkDR(DRM_BBX_Initialize(&g_BBContext));

ErrorExit:
	return dr;
}

/*
	Per-test-case teardown: releases every blob slot, resets it to NULL and
	shuts the global blackbox context down.
	lTCID, strTCName: not used by this function.
	Returns the result of DRM_BBX_Shutdown.
*/
DRM_RESULT BB_PostTestCase(long lTCID, char *strTCName)
{
	int iSlot = 0;

	while (iSlot < MAX_BLOB) {
		DX_VOS_MemFree(g_blob[iSlot]);
		g_blob[iSlot] = NULL;	/* keep the slot from dangling */
		iSlot++;
	}

	return DRM_BBX_Shutdown(&g_BBContext);
}

/* Forward declarations only: neither function is defined in this part of the
   file (presumably they are implemented elsewhere -- confirm). */
DRM_RESULT Test_CreateBindingInfo(long argc, char **argv);
DRM_RESULT Test_API_CanBind(long argc, char **argv);

/*
	Test API DRM_BBX_GetClientId
	argv[0]: client id pointer: NULL, NORMAL
	argv[1]: blackbox context: NULL, NORMAL
*/
DRM_RESULT Test_API_GetClientID(long argc, char **argv)
{
	DRM_RESULT dr;
	DRM_BB_CONTEXT *pBBContext;
	CLIENTID *pClientID;

	ChkArg(argc == 2);
	pClientID = argv[0]? &g_clientID: NULL;
	pBBContext = argv[1]? &g_BBContext: NULL;

	ChkDR(DRM_BBX_GetClientId(pClientID, pBBContext));
ErrorExit:
	return dr;
}

/* Verify the global clientid is correctly encrypted */
DRM_RESULT Test_VerifyClientID(long argc, char **argv)
{
	DRM_RESULT dr;
	CLIENTID *pClientID = &g_clientID;
	
	/* Verify the version is correct */
	if (DX_VOS_MemCmp(pClientID->version, CLIENT_ID_VER, VERSION_LEN))
		ChkDR(DRM_E_FAIL);

	/* Decrypt the client id and verify its cert */
	ChkDR(tDecryptClientID(g_BBContext.CryptoContext.rgbCryptoContext, pClientID));	
	if (DX_VOS_MemCmp(&pClientID->pk, &g_BBContext.cachedCertValues.m_BBCompatibilityCert, sizeof(PKCERT)))
		dr = DRM_S_FALSE;
ErrorExit:
	return dr;
}




/* Fill one global blob slot with freshly generated random bytes.
	argv[0]: index of the global blob slot that receives the random bytes
	argv[1]: number of random bytes to generate
	Any buffer previously held by the slot is released first.
	Fails the argument check if argc is not 2, an argument is NULL or the
	index is outside [0, MAX_BLOB).
*/
DRM_RESULT Test_GenRandomBlob(long argc, char **argv)
{
	DRM_RESULT dr;
	int iSlot;
	DRM_DWORD cbRandom;

	ChkArg(argc == 2);
	ChkArg(argv[0] && argv[1]);

	iSlot = OEM_atoi(argv[0]);
	ChkArg(iSlot >= 0);
	ChkArg(iSlot < MAX_BLOB);

	cbRandom = OEM_atol(argv[1]);

	/* replace whatever the slot held; the slot is overwritten before the
	   allocation result is checked, so it never keeps the freed pointer */
	DX_VOS_MemFree(g_blob[iSlot]);
	g_blob[iSlot] = (DRM_BYTE*)DX_VOS_MemMalloc(cbRandom * sizeof(DRM_BYTE));
	ChkMem(g_blob[iSlot]);

	ChkDR(OEM_GenRandomBytes(g_blob[iSlot], cbRandom));

ErrorExit:
	return dr;
}

/* Test API function DRM_BBX_HashValue
	argv[0]: index to the global random blob to hold the data to be hashed.
	argv[1]: the length of the input data.
	argv[2]: index to the global memory to save the hash value
	argv[3]: blackbox context: NULL, NORMAL
*/
DRM_RESULT Test_API_HashValue(long argc, char **argv)
{
	DRM_RESULT dr;
	int iRand, iHash;
	DRM_DWORD dwLen;
	DRM_BB_CONTEXT *pBBContext;

	ChkArg(argc == 4 && argv[0] && argv[1] && argv[2]);

	iRand = OEM_atoi(argv[0]);
	dwLen = OEM_atol(argv[1]);
	iHash = OEM_atoi(argv[2]);
	pBBContext = argv[3]? &g_BBContext: NULL;

	ChkArg(iRand >= 0 && iRand < MAX_BLOB && iHash >= 0 && iHash < MAX_BLOB && iRand != iHash);

	DX_VOS_MemFree(g_blob[iHash]);
	ChkMem(g_blob[iHash] = (DRM_BYTE*)DX_VOS_MemMalloc(SHA_DIGEST_LEN));
	ChkDR(DRM_BBX_HashValue(g_blob[iRand], dwLen, g_blob[iHash], pBBContext));
ErrorExit:
	return dr;
}

/* Perform a memory operation on two global blobs
	argv[0]: operation name: "memcmp" or "memcpy"
	argv[1]: index of the first blob (the destination for "memcpy")
	argv[2]: index of the second blob (the source for "memcpy")
	argv[3]: number of bytes to compare or copy
	Returns DRM_S_FALSE when "memcmp" finds a difference, DRM_E_INVALIDARG for
	an unknown operation, and DRM_SUCCESS otherwise (assuming the argument
	checks pass).
	NOTE(review): the blob pointers are used as they are; a slot that was
	never allocated is passed on as NULL.
*/
DRM_RESULT Test_BlobOps(long argc, char **argv)
{
	DRM_RESULT dr = DRM_SUCCESS;
	int iFirst, iSecond;
	DRM_DWORD cbBlob;

	ChkArg(argc == 4);
	ChkArg(argv[0] && argv[1] && argv[2] && argv[3]);

	iFirst = OEM_atoi(argv[1]);
	iSecond = OEM_atoi(argv[2]);
	cbBlob = OEM_atol(argv[3]);

	ChkArg(iFirst >= 0 && iFirst < MAX_BLOB);
	ChkArg(iSecond >= 0 && iSecond < MAX_BLOB);

	if (DX_VOS_StrCmp(argv[0], "memcmp") == 0) {
		/* report a difference as DRM_S_FALSE rather than as an error */
		if (DX_VOS_MemCmp(g_blob[iFirst], g_blob[iSecond], cbBlob)) {
			dr = DRM_S_FALSE;
		}
	} else if (DX_VOS_StrCmp(argv[0], "memcpy") == 0) {
		DX_VOS_FastMemCpy(g_blob[iFirst], g_blob[iSecond], cbBlob);
	} else {
		dr = DRM_E_INVALIDARG;
	}

ErrorExit:
	return dr;
}

/* Encrypt a global blob with the public key of the blackbox compatibility
   certificate cached in the global context.
	argv[0]: index of the blob to encrypt
	argv[1]: size of the blob to encrypt
	argv[2]: index of the blob slot that receives the encrypted output;
	         must differ from argv[0]
	argv[3]: symmetric key size in bytes
	All four arguments are required because each one is parsed as a number.
	The output slot is released and re-allocated with
	PK_ENC_CIPHERTEXT_LEN + argv[1] bytes.
	Returns DRM_E_INVALIDARG-style failure from the argument checks, an
	allocation failure from ChkMem, or the result of DRM_PK_EncryptLarge.
*/
DRM_RESULT Test_PKEncryptLarge(long argc, char **argv)
{
	DRM_RESULT dr;
	int iIn, iOut;
	DRM_DWORD cbIn, dwSymKeySize;

	/* every argument is handed to OEM_atoi/OEM_atol below, so none may be NULL */
	ChkArg(argc == 4 && argv[0] && argv[1] && argv[2] && argv[3]);

	iIn = OEM_atoi(argv[0]);
	cbIn = OEM_atol(argv[1]);
	iOut = OEM_atoi(argv[2]);
	dwSymKeySize = OEM_atol(argv[3]);

	ChkArg(iIn != iOut && iIn >= 0 && iIn < MAX_BLOB && iOut >= 0 && iOut < MAX_BLOB);

	/* the output buffer holds the ciphertext header plus the payload */
	DX_VOS_MemFree(g_blob[iOut]);
	ChkMem(g_blob[iOut] = (DRM_BYTE*)DX_VOS_MemMalloc(PK_ENC_CIPHERTEXT_LEN + cbIn));

	ChkDR(DRM_PK_EncryptLarge(&g_BBContext.cachedCertValues.m_BBCompatibilityCert.pk.pk,
	                           g_blob[iIn],
	                           cbIn,
	                           g_blob[iOut],
	                           dwSymKeySize,
	                          &g_BBContext.CryptoContext));
ErrorExit:
	return dr;
}

/* Test API DRM_BBX_DecryptLicense
	argv[0]: index of the global blob holding the encrypted license (required)
	argv[1]: length of the license; PK_ENC_CIPHERTEXT_LEN is added to it for
	         the call (required)
	argv[2]: index of the global blob slot that receives the decrypted
	         license, or NULL to decrypt in place
	argv[3]: blackbox context: NULL passes NULL, anything else passes the
	         global blackbox context
	Returns DRM_SUCCESS when DRM_BBX_DecryptLicense reports success,
	DRM_S_FALSE when it does not, or a failure code from the argument and
	allocation checks.
*/
DRM_RESULT Test_API_DecryptLicense(long argc, char **argv)
{
	DRM_RESULT dr;
	int iIn, iOut;
	DRM_DWORD cbIn;
	DRM_BYTE *pLicense = NULL;
	DRM_BB_CONTEXT *pBBContext;

	/* argv[0] and argv[1] are parsed unconditionally, so they must not be
	   NULL; argv[2] and argv[3] may legitimately be NULL */
	ChkArg(argc == 4 && argv[0] && argv[1]);

	iIn = OEM_atoi(argv[0]);
	cbIn = OEM_atol(argv[1]);
	ChkArg(iIn >= 0 && iIn < MAX_BLOB);

	if (argv[2]) {
		/* separate output buffer requested: allocate it in the chosen slot */
		iOut = OEM_atoi(argv[2]);
		ChkArg(iIn != iOut && iOut >= 0 && iOut < MAX_BLOB);

		DX_VOS_MemFree(g_blob[iOut]);
		ChkMem(pLicense = g_blob[iOut] = (DRM_BYTE*)DX_VOS_MemMalloc(cbIn));
	}

	pBBContext = argv[3]? &g_BBContext: NULL;

	if (DRM_BBX_DecryptLicense(g_blob[iIn], cbIn + PK_ENC_CIPHERTEXT_LEN, pLicense, pBBContext)) {
		dr = DRM_SUCCESS;
		if (!pLicense) /* in-place decryption, move the decrypted license up. */
			DRM_memmove(g_blob[iIn], g_blob[iIn] + PK_ENC_CIPHERTEXT_LEN, cbIn);
	} else
		dr = DRM_S_FALSE;
ErrorExit:
	return dr;
}

/* Test API DRM_BBX_CipherKeySetup
	argv[0]: index of the global binding info, or NULL to pass a NULL
	         binding pointer
	argv[1]: index of the global cipher context to set up, or NULL to pass a
	         NULL cipher context
	argv[2]: blackbox context: NULL passes NULL, anything else passes the
	         global blackbox context
	Returns a failure from the argument checks or the result of
	DRM_BBX_CipherKeySetup.
*/
DRM_RESULT Test_API_CipherKeySetup(long argc, char** argv)
{
	DRM_RESULT dr;
	int i;
	DRM_BINDING_INFO *pBinding = NULL;
	/* must start as NULL: when argv[1] is NULL no slot is selected and the
	   API has to receive a NULL pointer rather than an indeterminate one */
	DRM_CIPHER_CONTEXT *pCipher = NULL;
	DRM_BB_CONTEXT *pBBContext;

	ChkArg(argc == 3);

	/* set up binding info parameter */
	if (argv[0]) {
		i = OEM_atoi(argv[0]);
		ChkArg(i >= 0 && i < MAX_GLOBAL_BINDING_BLOB);
		pBinding = &g_binding[i];
	}

	/* set up cipher parameter */
	if (argv[1]) {
		i = OEM_atoi(argv[1]);
		ChkArg(i >= 0 && i < MAX_CIPHEER);
		pCipher = &g_cipherContext[i];
	}

	/* set up blackbox context */
	pBBContext = argv[2]? &g_BBContext: NULL;

	/* a single binding entry is passed */
	ChkDR(DRM_BBX_CipherKeySetup(pBinding, 1, pCipher, pBBContext));
ErrorExit:
	return dr;
}

/* Shared worker for Test_API_Encrypt and Test_API_Decrypt.
	argv[0]: size of the blob to process (required)
	argv[1]: index of the global blob to process, or NULL to pass a NULL
	         data pointer
	argv[2]: index of the global cipher context, or NULL to pass a NULL
	         cipher context
	fEncrypt: TRUE calls DRM_BBX_Encrypt; FALSE calls DRM_BBX_InitDecrypt
	          followed by DRM_BBX_Decrypt
	Returns a failure from the argument checks or the result of the last
	DRM_BBX_* call that ran.
*/
static DRM_RESULT i_API_EncryptOrDecrypt(long argc, char **argv, DRM_BOOL fEncrypt)
{
	DRM_RESULT dr;
	int i;
	DRM_DWORD cbData;
	DRM_BYTE *pbData = NULL;
	DRM_BYTE *pbTail = NULL;
	DRM_CIPHER_CONTEXT *pCipherContext = NULL;

	ChkArg(argc == 3 && argv[0]);

	cbData = OEM_atol(argv[0]);

	/* set up the pbData parameter */
	if (argv[1]) {
		i = OEM_atoi(argv[1]);
		ChkArg(i >= 0 && i < MAX_BLOB);
		pbData = g_blob[i];
	}

	/* set up the cipher context */
	if (argv[2]) {
		i = OEM_atoi(argv[2]);
		ChkArg(i >= 0 && i < MAX_CIPHEER);
		pCipherContext = &g_cipherContext[i];
	}

	if (fEncrypt) {
		ChkDR(DRM_BBX_Encrypt(cbData, pbData, pCipherContext));
	} else {
		/* DRM_BBX_InitDecrypt is handed a pointer to the final 15 bytes of the
		   buffer, or to its start when the buffer is shorter than that.
		   The offset is skipped for a NULL buffer: pointer arithmetic on NULL
		   is undefined and would give the API a bogus non-NULL pointer
		   instead of the NULL this test case is meant to pass. */
		pbTail = pbData;
		if (pbData != NULL && cbData >= 15) {
			pbTail = pbData + (cbData - 15);
		}

		ChkDR( DRM_BBX_InitDecrypt( pCipherContext, pbTail, cbData ) );
		ChkDR(DRM_BBX_Decrypt(cbData, pbData, pCipherContext));
	}
ErrorExit:
	return dr;
}

/* Test API DRM_BBX_Encrypt
	argv[0]: size of the blob to encrypt
	argv[1]: index of the blob to encrypt
	argv[2]: index of the cipher context
	Forwards to i_API_EncryptOrDecrypt with the encrypt flag set and returns
	its result.
*/
DRM_RESULT Test_API_Encrypt(long argc, char **argv)
{
	DRM_RESULT drEncrypt = i_API_EncryptOrDecrypt(argc, argv, TRUE);

	return drEncrypt;
}

/* Test API DRM_BBX_Decrypt
	argv[0]: size of the blob to decrypt
	argv[1]: index of the blob to decrypt
	argv[2]: index of the cipher context
	Forwards to i_API_EncryptOrDecrypt with the encrypt flag cleared and
	returns its result.
*/
DRM_RESULT Test_API_Decrypt(long argc, char **argv)
{
	DRM_RESULT drDecrypt = i_API_EncryptOrDecrypt(argc, argv, FALSE);

	return drDecrypt;
}

DRM_RESULT Test_RenewBlackboxContext(long argc, char **argv)
{
	DRM_RESULT dr;

	ChkDR(DRM_BBX_Shutdown(&g_BBContext));
	ChkDR(TestGetDevCertValues(&g_BBContext));
	ChkDR(DRM_BBX_Initialize(&g_BBContext));
ErrorExit:
	return dr;
}

// IMPLEMENT_DEFAULT_WARPTEST
// 
// BEGIN_APIMAP(RefBlackBoxTest_ansi, "refbb")
// 	API_ENTRY(Test_API_GetClientID)
// 	API_ENTRY(Test_VerifyClientID)
// 	API_ENTRY(Test_CreateBindingInfo)
// 	API_ENTRY(Test_API_CanBind)
// 	API_ENTRY(Test_GenRandomBlob)
// 	API_ENTRY(Test_API_HashValue)
// 	API_ENTRY(Test_PKEncryptLarge)
// 	API_ENTRY(Test_API_DecryptLicense)
// 	API_ENTRY(Test_API_CipherKeySetup)
// 	API_ENTRY(Test_BlobOps)
// 	API_ENTRY(Test_API_Encrypt)
// 	API_ENTRY(Test_RenewBlackboxContext)
// 	API_ENTRY(Test_API_Decrypt)
// END_APIMAP

